CREATE OR REPLACE FUNCTION "pipeline"."recommendPretreat"("inputConfig" text)
  RETURNS "pg_catalog"."text" AS $BODY$
    import json
    import itertools
    import time
    import sys
    # Python 2 only: switch the interpreter-wide default codec to UTF-8 so the
    # implicit str <-> unicode conversions performed further down (e.g. calling
    # .decode() on a unicode object or .encode() on a byte string) use UTF-8.
    # reload(sys) is called first because setdefaultencoding is normally not
    # reachable on the already-imported sys module.
    # NOTE(review): this changes the default for every PL/Python function that
    # shares this interpreter -- confirm that is acceptable.
    reload(sys)
    sys.setdefaultencoding("utf8")
    
    def getTableMetaInfo(table_name):
        """Return a ``{column_name: data_type}`` mapping for *table_name*.

        *table_name* may be a bare table name ("t") or a dotted, qualified
        name ("schema.t").  The last dotted component is used as the table
        name and, when present, the component before it restricts the lookup
        to that schema so that same-named tables in other schemas do not
        leak their columns into the result.

        The lookup goes through a prepared plan with bound parameters, so the
        name is never spliced into the SQL text.

        Returns an empty dict when no matching column is found.
        """
        parts = table_name.split(".")
        bare_name = parts[-1]
        base_sql = "select column_name, data_type from information_schema.columns where "
        if len(parts) > 1:
            # Qualified name: match schema and table together.
            plan = plpy.prepare(
                base_sql + "table_schema = $1 and table_name = $2",
                ["text", "text"])
            rs = plpy.execute(plan, [parts[-2], bare_name])
        else:
            # Unqualified name: split() yields a single element, so there is
            # no schema component to filter on.
            plan = plpy.prepare(base_sql + "table_name = $1", ["text"])
            rs = plpy.execute(plan, [bare_name])
        return {line['column_name']: line['data_type'] for line in rs}
        
    def getBins(sourceTable, col, binCount):
        """Split the value range of a numeric column into *binCount* bins.

        Parameters:
            sourceTable -- table to read, as it should appear in the FROM clause.
            col         -- column whose min/max define the range.
            binCount    -- number of bins; must be a positive integer.

        Returns ``(bins, "success")`` where *bins* is a list of
        ``[left, right]`` pairs of equal width starting at the column minimum;
        the right edge of the last bin is ``max + 1``.  On any failure returns
        ``(None, message)`` instead of raising.

        The bounds are read with a single ``min()/max()`` aggregate query.
        Aggregates skip NULLs, whereas ``ORDER BY ... DESC LIMIT 1`` would
        return a NULL row first in PostgreSQL and make the conversion fail.
        Fixed result aliases are used so the lookup does not depend on how the
        server case-folds the column name.
        """
        sourceTable = decode_utf8(sourceTable)
        col = decode_utf8(col)
        # NOTE(review): col and sourceTable are interpolated into the SQL text
        # as identifiers; they must come from a trusted caller.
        sqlText = "select min({0}) as min_value, max({0}) as max_value from {1}"
        try:
            if binCount <= 0:
                return None, "binCount must be a positive integer"
            meta = getTableMetaInfo(sourceTable)
            # NOTE(review): same type list as isInteger(); "smallint" is not in
            # it, so such columns get float bins -- confirm this is intended.
            isInt = meta.get(col) in ("bigint", "tinyint", "integer")
            rs = plpy.execute(sqlText.format(col, sourceTable))
            if len(rs) == 0 or rs[0]["min_value"] is None:
                return None, "no non-null values in column {}".format(col)
            convert = int if isInt else float
            minValue = convert(rs[0]["min_value"])
            maxValue = convert(rs[0]["max_value"])
            if isInt:
                # Explicit floor division keeps integer bin edges.
                step = (maxValue - minValue) // binCount
            else:
                step = float(maxValue - minValue) / binCount
            bins = []
            for i in range(binCount):
                left = minValue + step * i
                right = minValue + step * (i + 1)
                if i == binCount - 1:
                    # Widen the last bin so the maximum value itself is covered.
                    right = maxValue + 1
                bins.append([left, right])
        except Exception as e:
            return None, str(e)

        return bins, "success"
            
    def getCategory(sourceTable, col):
        """Return the distinct values of *col* in *sourceTable*.

        Returns ``(values, "success")`` on success and ``(None, message)`` when
        the query fails.

        The result column is given a fixed ASCII alias.  Looking the row up by
        the caller-supplied column name is fragile: an unquoted alias is
        case-folded by the server, and a non-ASCII unicode key does not match
        the byte-string keys of the result row.
        """
        # NOTE(review): col and sourceTable are interpolated into the SQL text
        # as identifiers; they must come from a trusted caller.
        sqlText = "SELECT DISTINCT({}) as category_value from {}".format(
            decode_utf8(col), decode_utf8(sourceTable))
        try:
            rs = plpy.execute(sqlText)
            categories = [line["category_value"] for line in rs]
        except Exception as e:
            return None, str(e)

        return categories, "success"
    
    def isInteger(col, meta):
        """Tell whether *meta* records an integer data type for column *col*.

        *col* is decoded with decode_utf8 before being looked up in *meta*;
        a column missing from *meta* is reported as not integer.
        """
        integer_types = ("bigint", "tinyint", "integer")
        column_type = meta.get(decode_utf8(col))
        return column_type in integer_types
    
    def getProductResult(*args):
        """Return the Cartesian product of the given iterables as a list of lists."""
        return [list(combo) for combo in itertools.product(*args)]
      
    def validate(jsonStr):
        """Parse and sanity-check the JSON configuration string.

        Checks, in order:
          * the text parses as JSON and is an object;
          * "datasetId", "transformMethod" and "transformConfig" are present;
          * "transformConfig" is an object with "binFields" and "binMethod";
          * "targetField" is present unless "binMethod" is "count".

        Returns ``(config_dict, "success")`` when valid, otherwise
        ``(None, message)``.  Never raises.
        """
        try:
            jsonObj = json.loads(jsonStr)
        except Exception as e:
            # Malformed JSON or a non-string argument.
            return None, str(e)

        if not isinstance(jsonObj, dict):
            return None, "inputConfig must be a JSON object"
        # Membership tests instead of dict.has_key, matching the checks below.
        required = ("datasetId", "transformMethod", "transformConfig")
        if any(key not in jsonObj for key in required):
            return None, "datasetId or transformMethod or transformConfig is not exist"

        transformConfig = jsonObj.get("transformConfig")
        if not isinstance(transformConfig, dict):
            return None, "transformConfig must be a JSON object"
        if "binFields" not in transformConfig or "binMethod" not in transformConfig:
            return None, "binFields or binMethod not exist"

        binMethod = transformConfig.get("binMethod")
        if binMethod != "count" and "targetField" not in transformConfig:
            return None, "targetField not exist"
        return jsonObj, "success"
    
    def encode_utf8(s):
        """Return *s* encoded with the UTF-8 codec."""
        utf8_encoded = s.encode("utf-8")
        return utf8_encoded
       
    def decode_utf8(s):
        """Return *s* decoded with the UTF-8 codec."""
        utf8_decoded = s.decode("utf-8")
        return utf8_decoded
    
    def logError(msg, status, result):
        """Record *status* and *msg* on *result* (in place) and return it as JSON text."""
        for key, value in (("status", status), ("error_msg", msg)):
            result[key] = value
        serialized = json.dumps(result)
        return serialized
    
    def _buildBinColumns(cols, dataTypes, productItems, meta):
        """Build the SELECT expressions and output column names for the bin fields.

        Categorical columns are selected as-is.  Quantitative columns are
        wrapped in "pipeline"."getBinInt" or "pipeline"."getBinFloat"
        (chosen by isInteger) with the bins passed as a JSON literal, and are
        aliased "<column>_after".

        Returns ``(columnBody, colsAfter)``.
        """
        columnBody = []
        colsAfter = []
        for colName, dataType, product in zip(cols, dataTypes, productItems):
            if dataType == "categorical":
                columnBody.append(colName)
                colsAfter.append(colName)
                continue
            binFunc = "getBinInt" if isInteger(colName, meta) else "getBinFloat"
            aliasName = colName + "_after"
            columnBody.append("\"pipeline\".\"{}\"({}, '{}') as {}".format(
                binFunc, colName, json.dumps(product), aliasName))
            colsAfter.append(aliasName)
        return columnBody, colsAfter

    def _createOutputTable(sqlText, outPutTable, outputCols, result):
        """Run the CREATE TABLE statement and return the JSON response.

        On success the table name and its columns are appended to
        result["result"]["output_params"].  On failure the (possibly partial)
        table is dropped and an error response is returned.
        """
        try:
            plpy.execute(sqlText)
        except Exception as e:
            plpy.execute("DROP TABLE IF EXISTS {}".format(outPutTable))
            return logError(str(e), 500, result)
        result["result"]["output_params"].append({
            "out_table_name": outPutTable,
            "output_cols": outputCols
        })
        return json.dumps(result)

    def begin(inputConfig):
        """Validate the request, bin the requested columns and materialise the aggregate.

        *inputConfig* is the JSON request text.  For every entry in
        transformConfig.binFields a quantitative column gets equal-width bins
        from getBins and a categorical column is grouped by its raw value.
        The grouped result (row count for binMethod "count", AVG of
        targetField for "average") is written to a new table named
        pipeline.solid_rec_<timestamp>.

        Always returns a JSON string with "status" (0 on success, 500 on
        error), "error_msg" and "result" (echoed input plus the output table
        description).
        """
        result = {
            "status": 0,
            "error_msg": "success",
            "result": {
                "input_params": {},
                "output_params": []
            }
        }
        configJson, msg = validate(inputConfig)
        if configJson is None:
            return logError(msg, 500, result)
        result["result"]["input_params"] = configJson
        transformConfig = configJson.get("transformConfig")
        binFields = transformConfig.get("binFields")
        if not isinstance(binFields, list) or not binFields:
            # An empty field list would otherwise produce invalid SQL.
            return logError("binFields must be a non-empty list", 500, result)
        tableName = configJson.get("datasetId")

        # The three lists are appended to together, so they always stay aligned.
        cols = []
        dataTypes = []
        productItems = []
        for binField in binFields:
            if "id" not in binField or "type" not in binField:
                return logError("id or type not in binField", 500, result)
            colName = decode_utf8(binField.get("id"))
            dataType = binField.get("type")
            if dataType == "quantitative":
                try:
                    binCount = int(binField.get("binCount"))
                except (TypeError, ValueError):
                    return logError("binCount missing or not an integer", 500, result)
                product, msg = getBins(tableName, colName, binCount)
                if product is None:
                    return logError(msg, 500, result)
            elif dataType == "categorical":
                # Categorical columns are grouped by their raw value; no
                # category list is precomputed for them.
                product = []
            else:
                return logError("{} not support type".format(dataType), 500, result)
            cols.append(colName)
            dataTypes.append(dataType)
            productItems.append(product)

        binMethod = transformConfig.get("binMethod")
        try:
            meta = getTableMetaInfo(tableName)
        except Exception as e:
            return logError(str(e), 500, result)
        result["result"]["input_params"]["productItems"] = productItems

        # Shape of the generated statement, for reference:
        #   select name, "pipeline"."getBinInt"(a2, '[[1, 40], [40, 1000]]') as cate,
        #          count(1) as count from dataset.wyz_testdata GROUP BY (name, cate);
        # NOTE(review): column and table names are interpolated into the SQL
        # text as identifiers; they must come from a trusted caller.
        outPutTable = "pipeline.solid_rec_{}".format(int(time.time() * 100000))
        columnBody, colsAfter = _buildBinColumns(cols, dataTypes, productItems, meta)
        selectList = ",".join(columnBody)
        groupBy = "({})".format(",".join(colsAfter))

        if binMethod == "count":
            sqlText = """CREATE TABLE {} AS SELECT {}, count(1) as count from {} GROUP BY {}""".format(
                outPutTable, selectList, tableName, groupBy)
            aggregateCol = "count"
        elif binMethod == "average":
            targetField = transformConfig.get("targetField")
            if not isinstance(targetField, dict) or "id" not in targetField:
                return logError("id not in targetField", 500, result)
            targetCol = decode_utf8(targetField.get("id"))
            sqlText = """CREATE TABLE {} AS SELECT {}, AVG({}) as avg_{} from {} GROUP BY {}""".format(
                outPutTable, selectList, targetCol, targetCol, tableName, groupBy)
            aggregateCol = "avg_{}".format(targetCol)
        else:
            return logError("{} not support".format(binMethod), 500, result)

        return _createOutputTable(sqlText, outPutTable, colsAfter + [aggregateCol], result)
    # Entry point of the PL/Python body: hand the SQL argument to begin() and
    # return its JSON text as the function result.
    # NOTE(review): this relies on the PL/Python runtime executing the body
    # with __name__ set to "__main__"; if it were not, the body would fall
    # through and return None -- confirm.
    if __name__ == "__main__":
        return begin(inputConfig)
    
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100